%% solveHJB.m
%
% Solve the HJB equation and generate policy functions
%--------------------------------------------------------------------------


%--------------------------------------------------------------------------
%% Preliminaries
% Runs only when HJB_preSolved == 0. It allocates the work arrays used below,
% builds the initial value-function guess VAllr, and precomputes the
% Blanchard-Yaari death term. Otherwise the workspace copies of these
% variables are reused as they are.
if HJB_preSolved == 0
    % Convergence tracking: one entry per outer iteration, and one per
    % (rb_ind, rm_ind) pair within an iteration
    disttmp = zeros(Nr,Nr);
    dist    = zeros(maxit,1);

    % Finite-difference scratch arrays on the (b, a, z) grid
    c   = zeros(Nb,Na,Nz);
    dVb = zeros(Nb,Na,Nz);
    dVf = zeros(Nb,Na,Nz);

    % Diagonals of the b-drift matrix, one column per z state
    centdiag = zeros(Nb*Na,Nz);
    lowdiag  = zeros(Nb*Na,Nz);
    updiag   = zeros(Nb*Na,Nz);

    %Blanchard-Yaari Death Value Fxn (no taxes once retired so unit mass always paying)
    V_BY            = (min(zstates)+rbdeath*(bbb+h-aaa)).^(1-s)/(rho*(1-s));
    V_BY_stacked    = reshape(V_BY,[],1);
    v_stacked_death = la_BY*V_BY_stacked;

    % Initial value function guess: utility of consuming the cash flow
    % forever, u(c)/rho. Only the tax rate varies with rb_ind; the interest
    % and mortgage rates are taken at index 1 for every rb_ind.
    % NOTE(review): confirm rbAll(1)/rmortgageAll(1) is intended rather
    % than rbAll(rb_ind)/rmortgageAll(rb_ind).
    VAllr = zeros(Nb,Na,Nz,Nr);
    for rb_ind = 1:Nr
        v0 = (zzz*(1-taxRates(rb_ind)) + rbAll(1).*bbb + cc_wedge.*bbb_negative ...
            - rmortgageAll(1).*aaa).^(1-s)/(1-s)/rho;
        VAllr(:,:,:,rb_ind) = v0;
    end
    % Use the same guess for every mortgage-rate index (5th dimension)
    VAllr     = repmat(VAllr,[1 1 1 1 Nr]);
    VAllr_tmp = zeros(Nb,Na,Nz,Nr,Nr);
end


%--------------------------------------------------------------------------
%% SOLVE WITHOUT ADJUSTMENT
% Optional warm start. This iterates the HJB with the adjustment option
% switched off, which gives a better initial guess VAllr for the full
% problem below. In each outer iteration n, every (rb_ind, rm_ind) pair
% takes one implicit step of size Delta using only the previous iterate
% VAllr. Results go into VAllr_tmp, which replaces VAllr once all pairs
% are done.
if UseNoAdjustmentAsGuess == 1 %solve without adjustment (for better initialization)
    for n=1:maxit
        for rb_ind = 1:Nr
            la_r = la_mat_r(rb_ind,:);   % row rb_ind of the rate-switching matrix
            rb = rbAll(rb_ind);
            tr = taxRates(rb_ind);
            for rm_ind = 1:Nr
                rmortgage=rmortgageAll(rm_ind);
                if ARMFlag == 1
                    %Under ARMS, tie mortgage rate to spot rate
                    rmortgage=rmortgageAll(rb_ind);
                end
                V = VAllr(:,:,:,rb_ind,rm_ind);

                % forward difference
                dVf(1:Nb-1,:,:) = (V(2:Nb,:,:)-V(1:Nb-1,:,:))/db;
                dVf(Nb,:,:) = (zz_nob*(1-tr) + rb.*bmax + rb_avg_wedge(Nb)*bmax - (aa_nob.*(rmortgage+xi))).^(-s); %state constraint boundary condition
                % backward difference
                dVb(2:Nb,:,:) = (V(2:Nb,:,:)-V(1:Nb-1,:,:))/db;
                dVb(1,:,:) = (zz_nob*(1-tr) + rb.*bmin + rb_avg_wedge(1)*bmin - (aa_nob.*(rmortgage+xi))).^(-s); %state constraint boundary condition

                dVf = max(dVf,10^(-6));
                dVb = max(dVb,10^(-6));

                % Cash flow before consumption. It is also consumption when b
                % does not move, and savings are this minus consumption.
                c0 = zzz*(1-tr) + rb*bbb + rb_avg_wedge_all.*bbb - aaa.*(rmortgage+xi);
                %consumption and savings with forward difference
                cf = dVf.^(-1/s);
                ssf = c0 - cf;
                Hf = cf.^(1-s)/(1-s) + dVf.*ssf;
                %consumption and savings with backward difference
                cb = dVb.^(-1/s);
                ssb = c0 - cb;
                Hb = cb.^(1-s)/(1-s) + dVb.*ssb;

                %Upwind method makes choice of forward or backward differences
                I0 = (ssf<=0) .* (ssb>=0);
                Iunique = (ssb<0).*(ssf<=0) + (ssb>=0).*(ssf>0);
                Iboth = (ssb<0).*(ssf>0);
                Ib = Iunique.*(ssb<0) + Iboth.*(Hb>Hf);
                If = Iunique.*(ssf>0) + Iboth.*(Hb<=Hf);

                c = cf.*If + cb.*Ib + c0.*I0;
                util = c.^(1-s)/(1-s);

                %Matrix BB summarizing evolution of b
                X = -Ib.*ssb/db;
                Y = -If.*ssf/db + Ib.*ssb/db;
                Z = If.*ssf/db;

                % Diagonals in spdiags layout for a square matrix:
                % lowdiag(k) fills entry (k+1,k) and updiag(k+1) fills
                % entry (k,k+1). The padded zero rows stop b-drift from
                % linking the edge of one a (or z) slice to the next.
                % Building them directly replaces the former growing-array
                % loop, which was quadratic in Na and failed for Nb == 2.
                centdiag = reshape(Y,Nb*Na,Nz);
                lowdiag  = reshape([X(2:Nb,:,:); zeros(1,Na,Nz)],Nb*Na,Nz);
                updiag   = reshape([zeros(1,Na,Nz); Z(1:Nb-1,:,:)],Nb*Na,Nz);

                % Stacking the z columns keeps those zero boundary entries,
                % so one spdiags call yields the block-diagonal BB
                BB = spdiags(centdiag(:),0,Nb*Na*Nz,Nb*Na*Nz) ...
                   + spdiags([updiag(:);0],1,Nb*Na*Nz,Nb*Na*Nz) ...
                   + spdiags([lowdiag(:);0],-1,Nb*Na*Nz,Nb*Na*Nz);

                A = AA + BB + ZZ;

                u_stacked = reshape(util,Nb*Na*Nz,1);
                V_stacked = reshape(V,Nb*Na*Nz,1);

                % la_r(rb_ind) is the diagonal entry of la_mat_r for this rate
                B = (rho + la_BY - la_r(rb_ind) + 1/Delta)*speye(Nb*Na*Nz) - A;
                % Continuation value from switching to rate state i ~= rb_ind
                % (weighted by la_r(i)), with the mortgage index held at rm_ind
                v_stacked_rswitch = zeros(M,1);
                for i = [1:rb_ind-1,rb_ind+1:Nr]
                    v_stacked_rswitch=v_stacked_rswitch+la_r(i)*reshape(VAllr(:,:,:,i,rm_ind),M,1);
                end
                u_stacked_all = u_stacked + v_stacked_death + v_stacked_rswitch;
                vec = u_stacked_all + V_stacked/Delta;

                %Implicit updating
                V_stacked_2 = B\vec;

                V = reshape(V_stacked_2, Nb, Na, Nz);
                disttmp(rb_ind,rm_ind) = max(abs(V - VAllr(:,:,:,rb_ind,rm_ind)),[],'all');
                VAllr_tmp(:,:,:,rb_ind,rm_ind) = V;
            end
        end
        VAllr = VAllr_tmp;
        dist(n) = max(disttmp,[],'all');
        if dist(n)<tol
            disp(['Value Function Converged, Iteration = ' int2str(n)]);
            break
        else
            disp(['Initial guess, no adjustment, iter ', int2str(n), ', dist ' num2str(dist(n)) ]);
        end
    end
end


%--------------------------------------------------------------------------
%% SOLVE WITH ADJUSTMENT
% Outside of loop, generate adjustment target set
%   > Prepay: lower fixed cost, but no rate change and no cash-out option
%   > Refi: higher fixed cost, but rate change and cash-out option
% For every pre-adjustment (b,a) grid point, candidate post-adjustment
% mortgage balances aP are spread over NaP_* points. The matching liquid
% balance follows from bP = (b - a) - kappa + aP.

% Cant-adjust region: (b - a) < bmin - thetam*h + ka_rewrite, i.e. a
% rewrite with aP = thetam*h would still leave bP below bmin
cantAdjust = (bbb-aaa) < (bmin-(thetam*h)+ka_rewrite);

% Forced-adjustment intensity, zeroed inside the cant-adjust region and
% stored as a sparse diagonal
la_forced_diag = la_forced.*ones(Nb,Na,Nz);
la_forced_diag(cantAdjust) = 0;
la_forced_diag = la_forced_diag(:);
la_forced_mat  = spdiags(la_forced_diag,0,Nb*Na*Nz,Nb*Na*Nz);

%Assign value function in "cant adjust" region so household won't adjust
%   How? A) Take lowest value from solution with no adjustment, or
%        B) Assign value of being in worst part of state space in perpetuity
if UseNoAdjustmentAsGuess ~= 1 || sim_case == 0.51
    % (B): lowest z, rate index Nr, b = bmin, a = amax, forever
    cantAdjustVal = (zstates(1)*(1-taxRates(Nr)) + (rbAll(Nr)+cc_wedge)*bmin - (rmortgageAll(Nr)+xi).*amax).^(1-s)/(1-s)/rho;
else
    % (A)
    cantAdjustVal = min(VAllr,[],"all");
end

% Net liquid position b - a on the (b,a) grid
nw_pre = bbb(:,:,1)-aaa(:,:,1);

aP_targets_prepay  = zeros(NaP_prepay,Nb,Na);
bP_targets_prepay  = zeros(NaP_prepay,Nb,Na);
aP_targets_rewrite = zeros(NaP_rewrite,Nb,Na);
bP_targets_rewrite = zeros(NaP_rewrite,Nb,Na);

for na = 1:Na
    a_curr = a(na);
    for nb = 1:Nb
        % Smallest aP that keeps bP >= bmin after paying the fixed cost
        aPmin_prepay  = max(0, bmin+ka_prepay-nw_pre(nb,na));
        aPmin_rewrite = max(0, bmin+ka_rewrite-nw_pre(nb,na));
        % Largest rewrite aP: the thetam*h cap, or keeping bP just below bmax
        aPmax_rewrite = min(thetam*h, (bmax-.99*db)+ka_rewrite-nw_pre(nb,na));

        if nocashoutFlag == 1
            %if no cashouts, still let households roll kappa into mortgage balance
            aP_rewrite = linspace(aPmin_rewrite, min(a_curr+ka_rewrite, amax), NaP_rewrite);
        else
            aP_rewrite = linspace(min(aPmin_rewrite,aPmax_rewrite),aPmax_rewrite,NaP_rewrite);
        end
        % Prepayment can only lower the balance, up to the current a
        aP_prepay = linspace(min(aPmin_prepay,a_curr), a_curr, NaP_prepay);

        bP_rewrite = nw_pre(nb,na) - ka_rewrite + aP_rewrite;
        bP_prepay  = nw_pre(nb,na) - ka_prepay  + aP_prepay;

        aP_targets_rewrite(:,nb,na) = aP_rewrite;
        bP_targets_rewrite(:,nb,na) = bP_rewrite;
        aP_targets_prepay(:,nb,na)  = aP_prepay;
        bP_targets_prepay(:,nb,na)  = bP_prepay;
    end
end

% Iterate again on HJB equation, but now allowing for adjustment.
% Each outer iteration n updates every (rb_ind, rm_ind) pair from the
% previous iterate VAllr. Results go into VAllr_tmp and replace VAllr at
% the end of the pass. The first time dist < tol, finalIterations is
% switched on, and from then on each pass also stores policy functions.
% The loop stops at the first such pass with dist < tol and no LCP
% problem flagged.
finalIterations = 0;
keepIteratingLCP = 0;   % set to 1 within a pass if an LCP check fails
hjbConverged = false;   % set to true only when the loop exits via break
for n=1:maxit
    for rb_ind = 1:Nr
        la_r = la_mat_r(rb_ind,:);
        rb = rbAll(rb_ind);
        tr = taxRates(rb_ind);

        %IMPORTANT: If re-writing, set rm-index to current rb index
        %           If prepaying, keep rm-index the same
        % The rewrite values therefore depend on rb_ind only, so they are
        % interpolated once here instead of once per rm_ind.
        G_rewrite = griddedInterpolant(bbb,aaa,zzz,VAllr(:,:,:,rb_ind,rb_ind));
        vAdj_rewrite_all = zeros(Nz,NaP_rewrite,Nb,Na);
        for nz = 1:Nz
            vAdj_rewrite_all(nz,:,:,:) = G_rewrite(bP_targets_rewrite, aP_targets_rewrite, zstates(nz).*ones(NaP_rewrite,Nb,Na));
        end

        for rm_ind = 1:Nr
            rmortgage=rmortgageAll(rm_ind);
            if ARMFlag == 1
                %Under ARMS, tie mortgage rate to spot rate
                rmortgage=rmortgageAll(rb_ind);
            end
            V = VAllr(:,:,:,rb_ind,rm_ind);

            % forward difference
            dVf(1:Nb-1,:,:) = (V(2:Nb,:,:)-V(1:Nb-1,:,:))/db;
            dVf(Nb,:,:) = (zz_nob*(1-tr) + rb.*bmax + rb_avg_wedge(Nb)*bmax - (aa_nob.*(rmortgage+xi))).^(-s); %state constraint boundary condition
            % backward difference
            dVb(2:Nb,:,:) = (V(2:Nb,:,:)-V(1:Nb-1,:,:))/db;
            dVb(1,:,:) = (zz_nob*(1-tr) + rb.*bmin + rb_avg_wedge(1)*bmin - (aa_nob.*(rmortgage+xi))).^(-s); %state constraint boundary condition

            dVf = max(dVf,10^(-6));
            dVb = max(dVb,10^(-6));

            % Cash flow before consumption. It is also consumption when b does
            % not move, and savings are this minus consumption.
            c0 = zzz*(1-tr) + rb*bbb + rb_avg_wedge_all.*bbb - aaa.*(rmortgage+xi);
            %consumption and savings with forward difference
            cf = dVf.^(-1/s);
            ssf = c0 - cf;
            Hf = cf.^(1-s)/(1-s) + dVf.*ssf;
            %consumption and savings with backward difference
            cb = dVb.^(-1/s);
            ssb = c0 - cb;
            Hb = cb.^(1-s)/(1-s) + dVb.*ssb;

            %Upwind method makes choice of forward or backward differences
            I0 = (ssf<=0) .* (ssb>=0);
            Iunique = (ssb<0).*(ssf<=0) + (ssb>=0).*(ssf>0);
            Iboth = (ssb<0).*(ssf>0);
            Ib = Iunique.*(ssb<0) + Iboth.*(Hb>Hf);
            If = Iunique.*(ssf>0) + Iboth.*(Hb<=Hf);

            c = cf.*If + cb.*Ib + c0.*I0;
            util = c.^(1-s)/(1-s);

            %Matrix BB summarizing evolution of b
            X = -Ib.*ssb/db;
            Y = -If.*ssf/db + Ib.*ssb/db;
            Z = If.*ssf/db;

            % spdiags layout for a square matrix: lowdiag(k) fills entry
            % (k+1,k) and updiag(k+1) fills entry (k,k+1). The padded zero
            % rows stop b-drift from linking one a (or z) slice to the
            % next, so a single spdiags call gives the block-diagonal BB.
            centdiag = reshape(Y,Nb*Na,Nz);
            lowdiag  = reshape([X(2:Nb,:,:); zeros(1,Na,Nz)],Nb*Na,Nz);
            updiag   = reshape([zeros(1,Na,Nz); Z(1:Nb-1,:,:)],Nb*Na,Nz);
            BB = spdiags(centdiag(:),0,Nb*Na*Nz,Nb*Na*Nz) ...
               + spdiags([updiag(:);0],1,Nb*Na*Nz,Nb*Na*Nz) ...
               + spdiags([lowdiag(:);0],-1,Nb*Na*Nz,Nb*Na*Nz);

            A = AA + BB + ZZ;

            %--------------------------------------------------------------
            % Adjustment decision on the precomputed target grids. For each
            % (b,a) point and z state, take the best rewrite target and the
            % best prepay target, then keep the better of the two options.
            bAdjAux    = zeros(Nz,Nb,Na);
            aAdjAux    = zeros(Nz,Nb,Na);
            rewriteAux = zeros(Nz,Nb,Na);

            % Prepaying keeps the mortgage-rate index rm_ind
            G_prepay = griddedInterpolant(bbb,aaa,zzz,VAllr(:,:,:,rb_ind,rm_ind));
            vAdj_prepay_all = zeros(Nz,NaP_prepay,Nb,Na);
            for nz = 1:Nz
                vAdj_prepay_all(nz,:,:,:) = G_prepay(bP_targets_prepay, aP_targets_prepay, zstates(nz).*ones(NaP_prepay,Nb,Na));
            end

            Vstar = zeros(Nz,Nb,Na);
            for na = 1:Na
                for nb = 1:Nb
                    %Option 1: Rewrite Mortgage
                    aP_rewrite = aP_targets_rewrite(:,nb,na);
                    bP_rewrite = bP_targets_rewrite(:,nb,na);
                    vAdj_rewrite = vAdj_rewrite_all(:,:,nb,na)';
                    %adjust to make sure no off-grid values chosen
                    vAdj_rewrite(bP_rewrite<bmin-1e-10,:) = cantAdjustVal;
                    [vmax_rewrite,idx_rewrite] = max(vAdj_rewrite); %find max and argmax
                    %Option 2: Prepay Mortgage
                    aP_prepay = aP_targets_prepay(:,nb,na);
                    bP_prepay = bP_targets_prepay(:,nb,na);
                    vAdj_prepay = vAdj_prepay_all(:,:,nb,na)';
                    %adjust to make sure no off-grid values chosen
                    vAdj_prepay(bP_prepay<bmin-1e-10,:) = cantAdjustVal;
                    [vmax_prepay, idx_prepay] = max(vAdj_prepay); %find max and argmax
                    %Combine. The value uses max rather than a use_rw-weighted
                    %blend, because the blend evaluates Inf*0 = NaN whenever
                    %the losing option's value is infinite.
                    use_rw = (vmax_rewrite>vmax_prepay);
                    Vstar(:,nb,na) = max(vmax_rewrite, vmax_prepay);
                    aAdjAux(:,nb,na) = aP_rewrite(idx_rewrite)'.*(use_rw) + aP_prepay(idx_prepay)'.*(1-use_rw);
                    bAdjAux(:,nb,na) = bP_rewrite(idx_rewrite)'.*(use_rw) + bP_prepay(idx_prepay)'.*(1-use_rw);
                    rewriteAux(:,nb,na) = use_rw;
                end
            end
            Vstar = permute(Vstar,[2,3,1]);
            Vstar(cantAdjust) = cantAdjustVal;
            Vstar = Vstar(:);

            %--------------------------------------------------------------
            if rational_slow_refi ~= 1
                % Solve using LCP (instant adjustment opportunities)
                u_stacked = reshape(util,Nb*Na*Nz,1);
                V_stacked = reshape(V,Nb*Na*Nz,1);

                B = (rho + la_BY - la_r(rb_ind) + 1/Delta)*speye(Nb*Na*Nz) + la_forced_mat - A;
                v_stacked_forced = la_forced_diag.*Vstar(:);
                v_stacked_rswitch = zeros(M,1);
                for i = [1:rb_ind-1,rb_ind+1:Nr]
                    v_stacked_rswitch=v_stacked_rswitch+la_r(i)*reshape(VAllr(:,:,:,i,rm_ind),M,1);
                end
                u_stacked_all = u_stacked + v_stacked_death + v_stacked_forced + v_stacked_rswitch;
                vec = u_stacked_all + V_stacked/Delta;

                q = -vec + B*Vstar;

                %using Yuval Tassa's Newton-based LCP solver, download from http://www.mathworks.com/matlabcentral/fileexchange/20952
                z0 = V_stacked-Vstar; lbnd = zeros(Nb*Na*Nz,1); ubnd = Inf*ones(Nb*Na*Nz,1);
                z = LCP(B,q,lbnd,ubnd,z0,0);

                LCP_error = max(abs(z.*(B*z + q))); % check the accuracy of the LCP solution
                if LCP_error > 10^(-6)
                    disp('LCP not solved')
                    keepIteratingLCP = 1;
                end
                if min(z(cantAdjust(:))) <= 10^(-6)
                    disp('Adjustment occurring in non-adjustment region');
                    keepIteratingLCP = 1;
                end

                V_stacked = z+Vstar; %update value function

            elseif rational_slow_refi == 1
                % Solve with slow arrival of adjustment opportunities (arrive at rate la_slowrefi)
                u_stacked = reshape(util,Nb*Na*Nz,1);
                V_stacked = reshape(V,Nb*Na*Nz,1);
                V_optionalAdjust = max(Vstar, V_stacked);

                B = (rho + la_BY + la_slowrefi - la_r(rb_ind) + 1/Delta)*speye(Nb*Na*Nz) + la_forced_mat - A;
                v_stacked_adjust = la_slowrefi*V_optionalAdjust;
                v_stacked_forced = la_forced_diag.*Vstar(:);
                v_stacked_rswitch = zeros(M,1);
                for i = [1:rb_ind-1,rb_ind+1:Nr]
                    v_stacked_rswitch=v_stacked_rswitch+la_r(i)*reshape(VAllr(:,:,:,i,rm_ind),M,1);
                end
                u_stacked_all = u_stacked + v_stacked_death + v_stacked_adjust + v_stacked_forced + v_stacked_rswitch;
                vec = u_stacked_all + V_stacked/Delta;

                %Implicit updating
                V_stacked = B\vec;
            end
            V = reshape(V_stacked, Nb, Na, Nz);
            disttmp(rb_ind,rm_ind) = max(abs(V - VAllr(:,:,:,rb_ind,rm_ind)),[],'all');
            VAllr_tmp(:,:,:,rb_ind,rm_ind) = V;

            %--------------------------------------------------------------
            % Store policy functions if converged
            if finalIterations == 1
                %Adjustment targets
                if rational_slow_refi ~= 1
                    adjStore{rb_ind,rm_ind} = (abs(V_stacked - Vstar)<10^(-6)); % allow for some tolerance in defining adjustment decision
                elseif rational_slow_refi == 1
                    adjStore{rb_ind,rm_ind} = (V_optionalAdjust==Vstar);
                end
                bAdj_combineStore{rb_ind,rm_ind} = reshape(permute(bAdjAux,[2,3,1]),Nb*Na,Nz);
                aAdj_combineStore{rb_ind,rm_ind} = reshape(permute(aAdjAux,[2,3,1]),Nb*Na,Nz);
                rewrite_combineStore{rb_ind,rm_ind} = reshape(permute(rewriteAux,[2,3,1]),Nb*Na,Nz);

                %Consumption and saving
                if beta == 1
                    sssStore{rb_ind,rm_ind} = zzz*(1-tr) + rb.*bbb + rb_avg_wedge_all.*bbb - aaa.*(rmortgage+xi) - c;
                    cStore{rb_ind,rm_ind} = c;
                elseif beta < 1
                    sss_expectedStore{rb_ind,rm_ind} = zzz*(1-tr) + rb.*bbb + rb_avg_wedge_all.*bbb - aaa.*(rmortgage+xi) - c;
                    c_nbd = beta^(-1/s).*c; %consumption of naif
                    c_nbd(1, :, :) = min(c_nbd(1,:,:),c0(1,:,:)); %account for constraint
                    sss = zzz*(1-tr) + rb.*bbb + rb_avg_wedge_all.*bbb - aaa.*(rmortgage+xi) - c_nbd;
                    sssStore{rb_ind,rm_ind} = sss;
                    cStore{rb_ind,rm_ind} = c_nbd;

                    %Update Transition Matrix (drift from the naif's savings)
                    Ib = sss<0;
                    If = sss>0;
                    I0 = 1 - If - Ib;
                    X = -Ib.*sss/db;
                    Y = -If.*sss/db + Ib.*sss/db;
                    Z = If.*sss/db;

                    centdiag = reshape(Y,Nb*Na,Nz);
                    lowdiag  = reshape([X(2:Nb,:,:); zeros(1,Na,Nz)],Nb*Na,Nz);
                    updiag   = reshape([zeros(1,Na,Nz); Z(1:Nb-1,:,:)],Nb*Na,Nz);
                    BB = spdiags(centdiag(:),0,Nb*Na*Nz,Nb*Na*Nz) ...
                       + spdiags([updiag(:);0],1,Nb*Na*Nz,Nb*Na*Nz) ...
                       + spdiags([lowdiag(:);0],-1,Nb*Na*Nz,Nb*Na*Nz);
                    A = AA + BB + ZZ;
                else
                    error('Define beta<=1');
                end
                AStore{rb_ind,rm_ind} = A;
            end
        end
    end

    %Update value function all at once at end of iteration
    VAllr = VAllr_tmp;
    dist(n) = max(disttmp,[],'all');
    if dist(n)<tol && finalIterations == 1 && keepIteratingLCP == 0
        disp(['Value Function Converged, Iteration = ' int2str(n)]);
        hjbConverged = true;
        break
    elseif dist(n)<tol && finalIterations == 0
        disp(['With adjustment, adjustment, iter ', int2str(n), ', dist ' num2str(dist(n)) ]);
        finalIterations = 1;
    else
        disp(['With adjustment, adjustment, iter ', int2str(n), ', dist ' num2str(dist(n)) ]);
    end
    keepIteratingLCP = 0;
end
% Without convergence the policy stores may be stale or never created
if ~hjbConverged
    warning('solveHJB:notConverged', ...
        'HJB with adjustment did not converge within maxit = %d iterations', maxit);
end


% Free memory: drop the gridded interpolants and the interpolated
% adjustment-value arrays built during the iteration
clearvars('G_rewrite','G_prepay','vAdj_rewrite_all','vAdj_prepay_all');

